{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Merge Datasets\n",
    "\n",
    "This recipe demonstrates a simple pattern for merging FiftyOne Datasets via [Dataset.merge_samples()](https://voxel51.com/docs/fiftyone/api/fiftyone.core.dataset.html?highlight=merge_samples#fiftyone.core.dataset.Dataset.merge_samples).\n",
    "\n",
    "Merging datasets is an easy way to:\n",
    "\n",
    "-   Combine multiple datasets with information about the same underlying raw media (images and videos)\n",
    "-   Add model predictions to a FiftyOne dataset, to compare with ground truth annotations and/or other models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "In this recipe, we'll work with a dataset downloaded from the [FiftyOne Dataset Zoo](https://voxel51.com/docs/fiftyone/user_guide/dataset_creation/zoo.html).\n",
    "\n",
    "To access the dataset, install `torch` and `torchvision`, if necessary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Modify as necessary (e.g., GPU install). See https://pytorch.org for options\n",
    "!pip install torch\n",
    "!pip install torchvision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then download the test split of [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split 'test' already downloaded\r\n"
     ]
    }
   ],
   "source": [
    "# Download the test split of CIFAR-10\n",
    "!fiftyone zoo download cifar10 --splits test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merging model predictions\n",
    "\n",
    "Load the test split of CIFAR-10 into FiftyOne:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split 'test' already downloaded\n",
      "Loading 'cifar10' split 'test'\n",
      " 100% |████████████████████████████████████████████████| 10000/10000 [13.0s elapsed, 0s remaining, 790.0 samples/s]      \n",
      "Name:           merge-example\n",
      "Media type      image\n",
      "Num samples:    10000\n",
      "Persistent:     False\n",
      "Info:           {'classes': ['airplane', 'automobile', 'bird', ...]}\n",
      "Tags:           ['test']\n",
      "Sample fields:\n",
      "    media_type:   fiftyone.core.fields.StringField\n",
      "    filepath:     fiftyone.core.fields.StringField\n",
      "    tags:         fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)\n",
      "    metadata:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)\n",
      "    ground_truth: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Classification)\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import os\n",
    "\n",
    "import fiftyone as fo\n",
    "import fiftyone.zoo as foz\n",
    "\n",
    "# Load test split of CIFAR-10\n",
    "dataset = foz.load_zoo_dataset(\"cifar10\", split=\"test\", dataset_name=\"merge-example\")\n",
    "classes = dataset.info[\"classes\"]\n",
    "\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset contains ground truth labels in its `ground_truth` field:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Sample: {\n",
      "    'id': '5f778dd116c859ba8e9a0113',\n",
      "    'media_type': 'image',\n",
      "    'filepath': '/Users/Brian/fiftyone/cifar10/test/data/000001.jpg',\n",
      "    'tags': BaseList(['test']),\n",
      "    'metadata': None,\n",
      "    'ground_truth': <Classification: {\n",
      "        'id': '5f778dd116c859ba8e9a0112',\n",
      "        'label': 'cat',\n",
      "        'confidence': None,\n",
      "        'logits': None,\n",
      "    }>,\n",
      "}>\n"
     ]
    }
   ],
   "source": [
    "# Print a sample from the dataset\n",
    "print(dataset.first())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose you would like to add model predictions to some samples from the dataset.\n",
    "\n",
    "The usual way to do this is to iterate over the samples and add your predictions directly to each one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_inference(filepath):\n",
    "    # Run inference on `filepath` here.\n",
    "    # For simplicity, we'll just generate a random label\n",
    "    label = random.choice(classes)\n",
    "    \n",
    "    return fo.Classification(label=label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name:           merge-example\n",
      "Media type      image\n",
      "Num samples:    10000\n",
      "Persistent:     False\n",
      "Info:           {'classes': ['airplane', 'automobile', 'bird', ...]}\n",
      "Tags:           ['test']\n",
      "Sample fields:\n",
      "    media_type:   fiftyone.core.fields.StringField\n",
      "    filepath:     fiftyone.core.fields.StringField\n",
      "    tags:         fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)\n",
      "    metadata:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)\n",
      "    ground_truth: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Classification)\n",
      "    prediction:   fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Classification)\n"
     ]
    }
   ],
   "source": [
    "# Choose 100 samples at random\n",
    "random_samples = dataset.take(100)\n",
    "\n",
    "# Add model predictions to dataset\n",
    "for sample in random_samples:\n",
    "    sample[\"prediction\"] = run_inference(sample.filepath)\n",
    "    sample.save()\n",
    "\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, suppose you store the predictions in a separate dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name:           2020.10.02.16.30.42\n",
      "Media type      image\n",
      "Num samples:    100\n",
      "Persistent:     False\n",
      "Info:           {}\n",
      "Tags:           []\n",
      "Sample fields:\n",
      "    media_type: fiftyone.core.fields.StringField\n",
      "    filepath:   fiftyone.core.fields.StringField\n",
      "    tags:       fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)\n",
      "    metadata:   fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)\n",
      "    prediction: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Classification)\n"
     ]
    }
   ],
   "source": [
    "# Filepaths of images to process\n",
    "filepaths = [s.filepath for s in dataset.take(100)]\n",
    "\n",
    "# Run inference\n",
    "predictions = fo.Dataset()\n",
    "for filepath in filepaths:\n",
    "    sample = fo.Sample(filepath=filepath)\n",
    "    \n",
    "    sample[\"prediction\"] = run_inference(filepath)\n",
    "\n",
    "    predictions.add_sample(sample)\n",
    "\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can easily merge the `predictions` dataset into the main dataset via [Dataset.merge_samples()](https://voxel51.com/docs/fiftyone/api/fiftyone.core.dataset.html?highlight=merge_samples#fiftyone.core.dataset.Dataset.merge_samples).\n",
    "\n",
    "Let's start by loading a fresh copy of CIFAR-10 that doesn't have predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split 'test' already downloaded\n",
      "Loading 'cifar10' split 'test'\n",
      " 100% |████████████████████████████████████████████████| 10000/10000 [12.7s elapsed, 0s remaining, 775.6 samples/s]      \n"
     ]
    }
   ],
   "source": [
    "dataset2 = foz.load_zoo_dataset(\"cifar10\", split=\"test\", dataset_name=\"merge-example2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now merge the predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 100% |████████████████████████████████████████████████████| 100/100 [288.3ms elapsed, 0s remaining, 346.9 samples/s]      \n",
      "Dataset:        merge-example2\n",
      "Num samples:    100\n",
      "Tags:           ['test']\n",
      "Sample fields:\n",
      "    media_type:   fiftyone.core.fields.StringField\n",
      "    filepath:     fiftyone.core.fields.StringField\n",
      "    tags:         fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)\n",
      "    metadata:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)\n",
      "    ground_truth: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Classification)\n",
      "    prediction:   fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Classification)\n",
      "Pipeline stages:\n",
      "    1. Exists(field='prediction', bool=True)\n"
     ]
    }
   ],
   "source": [
    "# Merge predictions\n",
    "dataset2.merge_samples(predictions)\n",
    "\n",
    "# Verify that 100 samples in `dataset2` now have predictions\n",
    "print(dataset2.exists(\"prediction\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, samples with the same absolute `filepath` are merged. However, you can customize this as desired via various keyword arguments of [Dataset.merge_samples()](https://voxel51.com/docs/fiftyone/api/fiftyone.core.dataset.html?highlight=merge_samples#fiftyone.core.dataset.Dataset.merge_samples).\n",
    "\n",
    "For example, the command below will merge samples with the same base filename, ignoring the directory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 100% |████████████████████████████████████████████████████| 100/100 [293.0ms elapsed, 0s remaining, 341.3 samples/s]      \n"
     ]
    }
   ],
   "source": [
    "# Merge predictions, using the base filename of each sample to decide which samples to merge.\n",
    "# The merge into `dataset2` was already performed above, so the existing predictions are overwritten\n",
    "dataset2.merge_samples(predictions, key_fcn=lambda p: os.path.basename(p))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's print a sample with predictions to verify that the merge happened as expected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<SampleView: {\n",
      "    'id': '5f778df916c859ba8e9a79d1',\n",
      "    'media_type': 'image',\n",
      "    'filepath': '/Users/Brian/fiftyone/cifar10/test/data/000103.jpg',\n",
      "    'tags': BaseList(['test']),\n",
      "    'metadata': None,\n",
      "    'ground_truth': <Classification: {\n",
      "        'id': '5f778df916c859ba8e9a79d0',\n",
      "        'label': 'frog',\n",
      "        'confidence': None,\n",
      "        'logits': None,\n",
      "    }>,\n",
      "    'prediction': <Classification: {\n",
      "        'id': '5f778df216c859ba8e9a77cb',\n",
      "        'label': 'frog',\n",
      "        'confidence': None,\n",
      "        'logits': None,\n",
      "    }>,\n",
      "}>\n"
     ]
    }
   ],
   "source": [
    "# Print a sample with predictions\n",
    "print(dataset2.exists(\"prediction\").first())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
